
import os

# Working directory at start-up.
# NOTE(review): `wd` does not appear to be referenced anywhere else in this
# file - confirm before removing.
wd = os.getcwd()

# Imports
from torch import nn
import numpy as np
from collections import OrderedDict
import transformers
import torch
import random
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer
from transformers import AdamW
from datasets import load_dataset
import sys
import argparse
from typing import Tuple
from torch.utils.data.dataloader import DataLoader

import csv
import pickle
import pandas as pd

from utils_log import Log

from data_prep import create_dictionary_of_eval_datasets, prepare_all_dataloaders
from models import LogicModel

# Print the process id at start-up; the same pid is used as the run name
# (log file / saved model) when --name_id is left at 'default'.
print("\n \n PID is",os.getpid(), "\n \n")

def get_args():
    """
    Builds the command-line parser and parses the run options.

    On/off switches are declared as ints (0 or 1). Arguments that the parser
    does not recognise are ignored rather than rejected.

    Returns:
        params: argparse namespace with one attribute per option below
    """

    # Two help texts were originally written with a backslash continuation
    # inside the string literal, which embeds the continuation line's
    # indentation in the text; that padding is reproduced verbatim here.
    embedded_indent = " " * 18

    # One row per option: (flag, type, default, help), in registration order
    option_table = [
        # Arguments for modelling different scenarios
        ("--model_type", str, "bert-base-uncased",
         "Model to be used"),
        ("--random_seed", int, 42,
         "Choose the random seed"),
        ("--name_id", str, "default",
         "Name used for saved model and log file"),

        # Arguments for model training
        ("--epochs", int, 2,
         "Number of epochs for training"),
        ("--learning_rate", float, 5e-6,
         "Choose learning rate"),
        ("--weight_decay", float, 0.01,
         "Weight decay for AdamW"),

        # Reducing training data
        ("--reduce", int, 0,
         "Reduce dataset or not"),
        ("--reduce_number", int, 1000,
         "Number of observations to reduce to"),

        # Evaluate-only options
        ("--eval_only", int, 0,
         "Only evaluate without training"),
        ("--model_file", str, 'saved_model.pt',
         "File to load model from if eval_only mode"),

        # Learning schedule arguments
        ("--linear_schedule", int, 1,
         "To use linear schedule with warm up or not"),
        ("--warmup_epochs", int, 1,
         "Warm up period"),
        ("--warmdown_epochs", int, 1,
         "Warm down period"),

        # Train data arguments
        ("--save_train_data", int, 0,
         "If we save a CSV of the training data"),
        ("--load_train_data", int, 1,
         "If we load a CSV of the training data"),
        ("--fact_version", str, 'v1',
         "'v1', 'v2', 'both', 'both_add', 'split', 'split_coref', "
         + embedded_indent
         + "'extra', 'extra_v2', 'extended' or 'extended_v2'"),

        # Save and train arguments
        ("--load_model", int, 0,
         "Load a baseline model"),
        ("--load_logic_model", int, 0,
         "Load a previously saved logic model"),
        ("--load_id", str, "",
         "Location of model being loaded "
         + embedded_indent
         + "(if loading either baseline or logic_model)"),
        ("--h_facts", int, 0,
         "If we also use hypothesis specific facts"),
        ("--sent_loss_mult", float, 0.9,
         "Multiplier for sentence-level loss"),
    ]

    parser = argparse.ArgumentParser(description="Training model parameters")

    for flag, arg_type, default_value, help_text in option_table:
        parser.add_argument(
                flag,
                type=arg_type,
                default=default_value,
                help=help_text)

    # Unknown arguments are collected separately and discarded
    known_args, _unknown_args = parser.parse_known_args()

    return known_args

def set_seed(seed_value: int) -> None:
    """
    Seeds every random number generator the script relies on.

    Args:
        seed_value: chosen random seed
    """
    # Python's `random`, NumPy, torch and torch's CUDA generators, in order
    seed_functions = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )

    for seed_function in seed_functions:
        seed_function(seed_value)

def find_logic_predictions(
        att_unnorm_cont: torch.tensor, 
        att_unnorm_ent: torch.tensor, 
        dataset_name: str) -> int:

    """
    Turns the two attention-layer outputs into one class prediction

    The contradiction output is checked first, so it takes priority over the
    entailment output when both exceed the threshold.

    Args:
        att_unnorm_cont: output from contradiction detection attention layer
        att_unnorm_ent: output from entailment detection attention layer
        dataset_name: dataset name (not used by this function)

    Returns:
        2 if the largest contradiction value is above 0.5, otherwise 0 if
        the largest entailment value is above 0.5, otherwise 1
    """
    threshold = 0.5

    # Any single value above the threshold is enough to fire a class
    if torch.max(att_unnorm_cont) > threshold:
        return 2

    if torch.max(att_unnorm_ent) > threshold:
        return 0

    return 1


@torch.no_grad()
def evaluate(
        epoch: int,
        dataset_name: str, 
        dataloader_eval: DataLoader) -> int:
    """
    Evaluates NLI logic model on one dataset and logs its accuracy

    Every minibatch is counted as a single NLI instance: one prediction is
    made per minibatch and compared with the first label in that minibatch.

    Args:
        epoch: epoch number (not used by this function)
        dataset_name: description of the evaluation dataset
        dataloader_eval: iterable of minibatches, each a dictionary of
            tensors that includes a 'label' entry

    Returns:
        correct_logic_att: number of minibatches predicted correctly
    """

    logic_model.encoder.eval()
    # NOTE(review): `eval_()` is not a torch.nn.Module method (that would be
    # `eval()`); presumably the attention layers define it themselves -
    # confirm against the models module.
    logic_model.attention_ent.eval_()
    logic_model.attention_cont.eval_()

    correct_logic_att, total = 0, 0

    for batch in dataloader_eval:

        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = logic_model(batch)

        pred_logic_att = find_logic_predictions(
                outputs['cont']['att_unnorm'], 
                outputs['ent']['att_unnorm'],
                dataset_name)

        total += 1

        # The instance label is the same for each fact in a minibatch
        assert batch['label'][0] == batch['label'][-1]

        if pred_logic_att == batch['label'][0].item():
            correct_logic_att += 1

    # An empty dataloader would otherwise raise ZeroDivisionError here
    accuracy = correct_logic_att / total if total else 0.0

    model_log.msg(["Total accuracy (logic att):" + str(round(accuracy, 4))])

    model_log.msg(["Total correct & total:" + \
            str(correct_logic_att) + " & " + str(total)])

    return correct_logic_att
    
def get_att_loss(
        outputs: dict, 
        class_str: str, 
        desired_label: int,
        sent_loss: torch.Tensor, 
        fact_loss: torch.Tensor) \
                -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adds one attention layer's sentence loss and fact loss to the totals

    Both terms are squared errors against desired_label. The running totals
    are updated in place and also returned.

    Args:
        outputs: dictionary of model outputs, keyed by class_str and then by
            'sent_output' and 'att_unnorm'
        class_str: 'ent' or 'cont' for different attention layers
        desired_label: target value for this attention layer
        sent_loss: sentence loss for observation so far
        fact_loss: fact-level loss for observation so far

    Returns:
        sent_loss: updated sentence loss for observation
        fact_loss: updated fact-level loss for observation
    """
    layer_outputs = outputs[class_str]

    # Sentence loss: squared error of the sentence-level output
    sent_loss += (layer_outputs['sent_output'] - desired_label)**2

    # Fact loss: squared error of the largest unnormalised attention value
    fact_loss += torch.square(
            torch.max(layer_outputs['att_unnorm']) - desired_label)

    return sent_loss, fact_loss

def train() -> None:
    """
    Trains the logic model on the 'anli_train' dataloader

    For each minibatch the contradiction attention layer always contributes
    to the loss; the entailment attention layer contributes (and its
    optimizer steps) only when the true label is 0 or 1. The model is
    evaluated after every epoch.

    With reduced training data (params.reduce) the model is checkpointed
    before training and again whenever the per-epoch score improves; after
    the last epoch the best checkpoint is reloaded and a final evaluation
    over every evaluation dataset is run.
    """

    # Our train dataloader
    dataloader_train = train_dataloader['anli_train']
    best_dev = 0
    best_epoch = 0

    # Single location for the early-stopping checkpoint of this run
    checkpoint_path = os.getcwd() + "/savedmodel/saved_model_" \
            + name_id + '.pt'

    # When we reduce the training data size, we use early stopping
    if params.reduce:
        print("Saving model before training")
        torch.save(logic_model.state_dict(), checkpoint_path)

    for epoch in range(params.epochs):

        logic_model.encoder.train()
        logic_model.attention_ent.train()
        logic_model.attention_cont.train()

        for batch in dataloader_train:

            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = logic_model(batch)

            e_sup_value = outputs['ent']['label'].item()
            c_sup_value = outputs['cont']['label'].item()

            # We calculate losses
            sent_loss = torch.tensor([0.0]).to(device)
            fact_loss = torch.tensor([0.0]).to(device)

            # The entailment layer is only trained for labels 0 and 1
            uses_ent_layer = bool(
                    outputs['true_label'] == 1 or outputs['true_label'] == 0)

            # Update the loss from the entailment and cont attention layers
            if uses_ent_layer:
                sent_loss, fact_loss = get_att_loss(
                            outputs,
                            'ent',
                            e_sup_value,
                            sent_loss,
                            fact_loss)

            sent_loss, fact_loss = get_att_loss(
                        outputs,
                        'cont',
                        c_sup_value,
                        sent_loss,
                        fact_loss)

            loss = params.sent_loss_mult * sent_loss + fact_loss

            # NOTE(review): retain_graph=True keeps the autograd graph alive
            # after the backward pass; only one backward call per minibatch
            # is visible here, so it may be unnecessary - confirm.
            loss.backward(retain_graph=True)

            if uses_ent_layer:
                optimizer_ent.step()

            optimizer_contradiction.step()
            optimizer_encoder.step()

            # Schedules are only created when not reducing the training data
            if params.linear_schedule and not params.reduce:
                schedule_encoder.step()
                schedule_ent.step()
                schedule_cont.step()

            optimizer_encoder.zero_grad()
            optimizer_ent.zero_grad()
            optimizer_contradiction.zero_grad()

        dev_acc = evaluate_each_epoch(epoch)

        # Early stopping: keep the checkpoint with the best per-epoch score
        if params.reduce and dev_acc > best_dev:
            best_dev = dev_acc
            best_epoch = epoch + 1
            torch.save(logic_model.state_dict(), checkpoint_path)

    if params.reduce:

        model_log.msg(["The best epoch was epoch: " + str(best_epoch)])

        logic_model.load_state_dict(torch.load(checkpoint_path))

        # Index of the last epoch run; computed from params so that it is
        # defined even when no epoch ran (params.epochs == 0)
        last_epoch = max(params.epochs - 1, 0)

        _ = evaluate_each_epoch(last_epoch, True)

def evaluate_each_epoch(epoch: int, final_eval=False) -> int:
    """
    Evaluates model on each dataset after each epoch

    Args:
        epoch: epoch number (starting at 0)
        final_eval: if True, evaluate on every evaluation dataset and
            return 0; used once training has finished

    Returns:
        dev_scores: sum of the correct-prediction counts returned by
            evaluate() for the datasets evaluated (always 0 for the final
            evaluation)
    """

    # Initialised up front so it is defined even with no datasets at all
    dev_scores = 0

    if final_eval:

        model_log.msg(["Final evaluation"])

        for dataset_name, dataset in eval_dataloaders.items():
            model_log.msg(["Dataset: " + dataset_name])
            _ = evaluate(epoch, dataset_name, dataset)

        return dev_scores

    model_log.msg(["Epoch:" + str(epoch+1)])

    for dataset_name, dataset in eval_dataloaders.items():

        # For reduced training data, evaluate only on dev sets in training
        # .. evaluating on all test sets after training has finished.
        # A dev set is named 'double...' with 'dev' at positions -6 to -3.
        if (not params.reduce) or (dataset_name.startswith('double') and \
                dataset_name[-6:-3] == 'dev'):
            model_log.msg(["Dataset: " + dataset_name])

            # dev_scores used for early stopping with reduced dataset
            dev_scores += evaluate(epoch, dataset_name, dataset)

    return dev_scores

def create_lr_schedules():
    """
    Creates the learning-rate schedules for the three optimizers

    Returns:
        (schedule_encoder, schedule_ent, schedule_cont), or three Nones
        when the linear schedule is switched off or the training data is
        reduced
    """

    # No linear schedule for reduced training data
    use_schedules = params.linear_schedule and not params.reduce

    if not use_schedules:
        return None, None, None

    encoder_schedule = LambdaLR(optimizer_encoder, lr_lambda_enc)
    ent_schedule = LambdaLR(optimizer_ent, lr_lambda)
    cont_schedule = LambdaLR(optimizer_contradiction, lr_lambda)

    return encoder_schedule, ent_schedule, cont_schedule

def lr_lambda_enc(current_step: int) -> float:
    """
    Learning-rate multiplier for the encoder optimizer

    Rises linearly from 0 over num_warmup_steps, then falls linearly,
    reaching 0 at num_training_steps (never negative).

    Args:
        current_step: number of scheduler steps taken so far

    Returns:
        multiplier applied to the optimizer's base learning rate
    """

    warmup_steps = num_warmup_steps
    total_steps = num_training_steps

    if current_step >= warmup_steps:
        # Warm-down phase
        steps_left = float(total_steps - current_step)
        warmdown_length = float(max(1, total_steps - warmup_steps))
        return max(0.0, steps_left / warmdown_length)

    # Warm-up phase
    return float(current_step) / float(max(1, warmup_steps))


def lr_lambda(current_step: int) -> float:
    """
    Learning-rate multiplier for the attention-layer optimizers

    Rises linearly from 0 over num_warmup_steps, then falls linearly,
    reaching 0 at num_training_steps (never negative).

    Args:
        current_step: number of scheduler steps taken so far

    Returns:
        multiplier applied to the optimizer's base learning rate
    """

    warmup_steps = num_warmup_steps
    total_steps = num_training_steps

    if current_step >= warmup_steps:
        # Warm-down phase
        steps_left = float(total_steps - current_step)
        warmdown_length = float(max(1, total_steps - warmup_steps))
        return max(0.0, steps_left / warmdown_length)

    # Warm-up phase
    return float(current_step) / float(max(1, warmup_steps))


def get_loaded_state():
    """
    Loads a saved baseline state dict and renames its keys for the encoder

    The file is read from savedmodel/<params.load_id>.pt under the current
    working directory. For 'bert...' model types the 'bert.' key prefix is
    stripped and 'classifier.' keys are dropped; for 'microsoft...' model
    types the 'deberta.' prefix is stripped and 'classifier.' and 'pooler.'
    keys are dropped. For any other model type the state dict is returned
    unchanged.

    Returns:
        loaded_state_dict: the loaded state dict with renamed keys, in
            their original order
    """

    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from a trusted source. map_location='cpu' lets a checkpoint saved on
    # a GPU be read on a CPU-only machine.
    loaded_state_dict = torch.load(
            os.getcwd() + "/savedmodel/" + params.load_id + '.pt',
            map_location='cpu')

    if params.model_type.startswith('microsoft'):
        encoder_prefix = 'deberta.'
        dropped_prefixes = ('classifier.', 'pooler.')
    elif params.model_type.startswith('bert'):
        encoder_prefix = 'bert.'
        dropped_prefixes = ('classifier.',)
    else:
        return loaded_state_dict

    # Every key is removed and, unless dropped, re-inserted at the end in
    # turn, so the dictionary keeps its order and is updated in place
    for key in list(loaded_state_dict):
        value = loaded_state_dict.pop(key)

        if key.startswith(encoder_prefix):
            loaded_state_dict[key[len(encoder_prefix):]] = value
        elif not key.startswith(dropped_prefixes):
            loaded_state_dict[key] = value

    return loaded_state_dict


if __name__ == '__main__':

    # NOTE: the names assigned in this block are module-level globals that
    # the functions above read directly (params, name_id, model_log, device,
    # logic_model, the dataloaders, optimizers, schedules and step counts).

    params = get_args()

    # On/off options arrive as 0/1 ints; convert them to booleans
    for flag_name in ('reduce', 'linear_schedule', 'eval_only',
                      'save_train_data', 'load_train_data', 'load_model',
                      'load_logic_model', 'h_facts'):
        setattr(params, flag_name, bool(getattr(params, flag_name)))
    params.baseline = False

    # Explicit errors instead of assert, which is stripped under `python -O`
    if params.eval_only and not params.load_logic_model:
        raise ValueError("--eval_only 1 requires --load_logic_model 1")

    params.train_splits = ['train_r1', 'train_r2', 'train_r3']

    print(params)

    # The run name defaults to the process id
    if params.name_id == 'default':
        name_id = str(os.getpid())
    else:
        name_id = params.name_id

    if params.reduce and (params.load_train_data or params.save_train_data):
        raise ValueError("Can only load (or save) full training data")

    # Logging file
    log_file_name = 'log_logic_model_' + name_id + '.txt'
    model_log = Log(log_file_name, params)

    # Use the GPU when one is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    set_seed(params.random_seed)

    # Create folder for saving models
    os.makedirs('savedmodel', exist_ok=True)

    # We create a dictionary of the Huggingface datasets to be OOD datasets
    eval_data_list = create_dictionary_of_eval_datasets(
                params)

    tokenizer = AutoTokenizer.from_pretrained(
            params.model_type,
            truncation=False)

    # We create the training and evaluation dataloaders
    train_dataloader, eval_dataloaders = prepare_all_dataloaders(
                    eval_data_list,
                    params,
                    tokenizer)

    # Encoder dimension passed to LogicModel: 1024 for these models, else 768
    if params.model_type in ('microsoft/deberta-large',
                             'microsoft/deberta-xlarge',
                             'microsoft/deberta-v3-large'):
        dim = 1024
    else:
        dim = 768

    logic_model = LogicModel(
            dim, 
            params.model_type)

    logic_model.to(device)

    # Check if we load a previous (baseline) model into the encoder
    if params.load_model:

        loaded_state_dict = get_loaded_state()

        if params.model_type.startswith('bert'):

            logic_model.encoder.bert.load_state_dict(
                loaded_state_dict)

        elif params.model_type.startswith('microsoft'):

            logic_model.encoder.deberta.load_state_dict(
                    loaded_state_dict)

    # Check if we load a previously saved logic model
    if params.load_logic_model:

        # map_location lets a checkpoint saved on a GPU load on a CPU host
        loaded_state_dict = torch.load(
            os.getcwd() + "/savedmodel/" + params.load_id + '.pt',
            map_location=device)

        logic_model.load_state_dict(
                loaded_state_dict)

    # We create our optimizers
    optimizer_encoder = AdamW(
            list(logic_model.encoder.parameters()),
            lr=params.learning_rate,
            weight_decay=params.weight_decay)

    optimizer_ent = AdamW(
            list(logic_model.attention_ent.parameters()),
            lr=params.learning_rate,
            weight_decay=params.weight_decay)

    optimizer_contradiction = AdamW(
            list(logic_model.attention_cont.parameters()),
            lr=params.learning_rate,
            weight_decay=params.weight_decay)

    # We create our learning schedules
    dataloader_name = 'anli_train'

    # Warm-up and warm-down lengths are whole epochs, expressed in steps
    steps_per_epoch = len(train_dataloader[dataloader_name])
    num_warmup_steps = steps_per_epoch * params.warmup_epochs
    warm_down_steps = steps_per_epoch * params.warmdown_epochs
    num_training_steps = num_warmup_steps + warm_down_steps

    print("warm up starts:", num_warmup_steps)

    schedule_encoder, schedule_ent, schedule_cont = create_lr_schedules()

    if params.eval_only:

        # The weights were already restored above: eval_only requires
        # load_logic_model, which loads this same checkpoint file
        _ = evaluate_each_epoch(0, True)

        # NOTE(review): `evaluate_facts` is not an option defined by
        # get_args(), and neither `fact_eval_dataloaders` nor
        # `evaluate_facts()` is defined in this file, so this branch used
        # to fail with AttributeError. It is now skipped unless the
        # attribute has been set - restore the missing pieces or remove it.
        if getattr(params, 'evaluate_facts', False):
            for dataset_name, dataset in fact_eval_dataloaders.items():
                model_log.msg(["Fact evaluation for: " + dataset_name])
                evaluate_facts(
                0,
                dataset_name,
                dataset)

    else:

        train()
        print("All done")

        # With reduced data, train() has already saved the best checkpoint
        if not params.reduce:
            torch.save(logic_model.state_dict(),
                                os.getcwd() + "/savedmodel/saved_model_" \
                                        + name_id + '.pt')
            print("All models saved")
